import jsonlines
from tqdm import tqdm
import sys
import re
import random
import os


def parse_options(options):
    """Split an options string on its "(X)" markers and relabel the pieces.

    The string is split with the module-level ``pattern`` (a parenthesised
    capital letter). Whatever precedes the first marker is discarded, and the
    remaining pieces are stripped and labelled sequentially starting from
    ``chr(base_index)``, regardless of which letters the input used.

    Args:
        options: text containing markers such as "(A)", "(B)", ...

    Returns:
        A list of strings of the form "(<letter>) <stripped piece>".
    """
    pieces = re.split(pattern, options)
    labelled = []
    # pieces[0] is the text before the first marker and is intentionally dropped.
    for offset, text in enumerate(pieces[1:]):
        letter = chr(base_index + offset)
        labelled.append(f"({letter}) {text.strip()}")
    return labelled


def transfer_examples(examples):
    """Normalise example records to the short-key schema Q / O / A / E.

    Every record may spell its fields with the short keys (``Q``, ``O``,
    ``A``, ``E``) or the long ones (``question``, ``options``, ``answer``,
    ``explanation``). When both spellings are present the short key wins.

    Args:
        examples: iterable of mapping records.

    Returns:
        A new list with one ``{"Q", "O", "A", "E"}`` dict per input record.
        ``E`` is the empty string when the record has no explanation.

    Raises:
        KeyError: if a record has neither spelling of the question, options
            or answer field.
    """
    results = []
    for e in examples:
        question = e['Q'] if 'Q' in e else e['question']
        options = e['O'] if 'O' in e else e['options']
        answer = e['A'] if 'A' in e else e['answer']
        try:
            explanation = e['E'] if 'E' in e else e['explanation']
        except KeyError:
            # Only a *missing* explanation is tolerated; any other error
            # (including KeyboardInterrupt) must propagate instead of being
            # silently turned into an empty string.
            explanation = ""
        results.append({"Q": question, "O": options, "A": answer, "E": explanation})
    return results


# Input JSONL file, given as the first command-line argument.
path = sys.argv[1]
# Marker used by parse_options: a single capital letter in parentheses.
pattern = re.compile(r"\([A-Z]\)")
# Code point of the first option label produced by parse_options.
base_index = ord('A')

# Load every record up front; the context manager closes the reader afterwards.
with jsonlines.open(path, "r") as reader:
    data = list(reader)

# The output path is the input path with every '_kl' occurrence removed.
out_path = path.replace('_kl', '')
if os.path.exists(out_path):
    # An earlier output exists: use the smaller per-level quota, remember the
    # facts it already contains so they are skipped, and write to a separate
    # "_part2" file.
    thre = 200 if "ecare" in path else 50
    # Read the previous output *before* opening the writer, and keep the facts
    # in a set because they are probed once per fact of every input record.
    with jsonlines.open(out_path, "r") as reader:
        facts = {d['fact'] for d in reader}
    fo = jsonlines.open(out_path.replace(".jsonl", "_part2.jsonl"), "w")
else:
    # First run: larger per-level quota and nothing to skip.
    thre = 600 if "ecare" in path else 150
    fo = jsonlines.open(out_path, "w")
    facts = set()

# Difficulty buckets; they are filled once all scores have been aggregated.
passive, easy, normal, hard, unfair = [], [], [], [], []

# fact2score:    fact -> highest score that fact has in any record's 'K' mapping
# fact2examples: fact -> every record whose 'K' mapping mentions that fact
fact2score, fact2examples = {}, {}
for example in tqdm(data):
    knowledge = example['K']
    for fact, score in knowledge.items():
        # Facts listed in `facts` (taken from an earlier output) are skipped.
        if fact in facts:
            continue
        fact2examples.setdefault(fact, []).append(example)
        # Keep the maximum; on ties the first value seen is retained.
        if fact not in fact2score or score > fact2score[fact]:
            fact2score[fact] = score


# Assign each fact to a difficulty bucket according to its best score:
#   score < 0.1 -> passive, [0.1, 0.4) -> easy, [0.4, 0.7) -> normal,
#   [0.7, 1.0) -> hard, everything else -> unfair.
for fact, score in fact2score.items():
    if score < 0.1:
        bucket = passive
    elif score < 0.4:
        bucket = easy
    elif score < 0.7:
        bucket = normal
    elif score < 1.0:
        bucket = hard
    else:
        bucket = unfair
    bucket.append(fact)


def _sample_and_top_up(bucket, donor, quota):
    """Sample up to `quota` facts from `bucket`, topping up from `donor`.

    When `bucket` holds fewer than `quota` facts, the shortfall is taken from
    the tail of `donor` (in its current order) and removed from it.

    Returns:
        (picked, remaining_donor) as two new lists.
    """
    picked = random.sample(bucket, min(quota, len(bucket)))
    shortfall = quota - len(picked)
    if shortfall > 0:
        picked.extend(donor[-shortfall:])
        donor = donor[:-shortfall]
    return picked, donor


# Work from the hardest level downwards so that each level can borrow from the
# next easier one; `easy` is last and has nothing to borrow from.
unfair, hard = _sample_and_top_up(unfair, hard, thre)
hard, normal = _sample_and_top_up(hard, normal, thre)
normal, easy = _sample_and_top_up(normal, easy, thre)
easy = random.sample(easy, min(thre, len(easy)))

# Report how many facts ended up in each level that will be written.
print(f"[Easy]: {len(easy)}")
print(f"[Normal]: {len(normal)}")
print(f"[Hard]: {len(hard)}")
print(f"[Unfair]: {len(unfair)}")

# Order every level by ascending best score (stable, so ties keep their
# current relative order).
score_of = fact2score.__getitem__
easy.sort(key=score_of)
normal.sort(key=score_of)
hard.sort(key=score_of)
unfair.sort(key=score_of)

# Write one record per selected fact, level by level (easy -> unfair).
# The try/finally guarantees the writer is closed even if a write fails.
try:
    for level, level_facts in (
        ("easy", easy),
        ("normal", normal),
        ("hard", hard),
        ("unfair", unfair),
    ):
        for fact in level_facts:
            # Only examples in which this fact scores at least 0.1 are kept.
            supporting = [e for e in fact2examples[fact] if e['K'][fact] >= 0.1]
            fo.write({
                "fact": fact,
                "score": fact2score[fact],
                "level": level,
                "examples": transfer_examples(supporting),
            })
finally:
    fo.close()





